import sys
import time
import threading
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional
import numpy as np
from OpenGL.GL import *
from OpenGL.GLUT import *
from OpenGL.GL.shaders import compileProgram, compileShader
import pywifi
import tkinter as tk
from tkinter import ttk


@dataclass
class VisualizationConfig:
    """Tunable constants shared by the simulation, the scaffold and the GUI.

    The field order doubles as the positional constructor order, so new
    fields should be appended rather than inserted.
    """

    # Capacity of the particle and scaffold arrays.
    max_particles: int = 10_000
    max_scaffold: int = 1_000
    # A network counts as stable once the variance of its last
    # ``stable_frames`` signal samples drops below ``stable_threshold``.
    stable_threshold: float = 3.0
    stable_frames: int = 30
    # Slider start values; also the fallbacks when the GUI cannot be read.
    default_amplitude: float = 1.0
    default_noise: float = 0.3
    default_persistence: float = 0.92
    default_morph: float = 0.0


@dataclass
class CameraState:
    """Orbit-camera pose plus the tuning used to change it from user input.

    ``angle_x``/``angle_y`` are rotation angles in degrees (pitch is clamped
    to [-90, 90]); ``distance`` is clamped to [min_distance, max_distance].
    """
    angle_x: float = 30.0
    angle_y: float = 30.0
    distance: float = 4.5
    sensitivity: float = 0.5
    zoom_sensitivity: float = 0.1
    min_distance: float = 1.0
    max_distance: float = 10.0

    def reset(self):
        """Reset the camera pose to the class defaults.

        The defaults are read from the field declarations above (dataclass
        fields keep their default as a class attribute), so they are stated
        in exactly one place.
        """
        cls = type(self)
        self.angle_x = cls.angle_x
        self.angle_y = cls.angle_y
        self.distance = cls.distance

    def zoom(self, direction: int):
        """Zoom in (direction > 0) or out (direction < 0); 0 is a no-op.

        Each step scales the distance by ``1 -/+ zoom_sensitivity``. The
        result is clamped to [min_distance, max_distance].
        """
        if direction == 0:
            return
        factor = 1 - self.zoom_sensitivity if direction > 0 else 1 + self.zoom_sensitivity
        self.distance = min(max(self.distance * factor, self.min_distance), self.max_distance)

    def rotate(self, dx: float, dy: float):
        """Orbit by a mouse delta: dx changes angle_x freely, dy changes angle_y (clamped)."""
        self.angle_x += dx * self.sensitivity
        self.angle_y = max(-90.0, min(90.0, self.angle_y + dy * self.sensitivity))


class ParticleSystem:
    """Manages particle positions, velocities, colors and trails.

    The random walk runs in an unscaled base space that is clamped to
    [-1, 1]. ``positions`` is that base scaled by the most recent amplitude.
    Because of this split, changing the amplitude rescales the cloud. If the
    scaled positions were fed back into the simulation, the scale would
    compound every frame: amplitudes < 1 would collapse the cloud and
    amplitudes > 1 would pile particles up at the edges.
    """

    def __init__(self, config: "VisualizationConfig"):
        self.config = config
        shape = (config.max_particles, 3)
        # Unscaled simulation state, always within [-1, 1].
        self._base = np.random.uniform(-1, 1, shape).astype(np.float32)
        # Amplitude-scaled positions exposed to callers (starts at scale 1).
        self.positions = self._base.copy()
        self.velocities = np.random.uniform(-0.002, 0.002, shape).astype(np.float32)
        self.colors = np.zeros(shape, dtype=np.float32)
        self.trails = np.zeros(shape, dtype=np.float32)
        self.active_count = 0

    def update(self, amplitude: float, noise: float, persistence: float):
        """Advance one step and refresh the scaled positions and trails.

        Args:
            amplitude: Scale applied to the clamped base positions.
            noise: Standard deviation of the per-step Gaussian jitter.
                Negative values are treated as 0.
            persistence: Weight of the previous trail value in the
                exponential moving average (0..1).
        """
        jitter = np.random.normal(0, max(noise, 0.0), self._base.shape).astype(np.float32)
        self._base += self.velocities + jitter
        np.clip(self._base, -1, 1, out=self._base)
        # Scale into the existing array so references to `positions` stay valid.
        np.multiply(self._base, amplitude, out=self.positions)

        # Exponential moving average of the displayed positions, in place.
        self.trails *= persistence
        self.trails += (1 - persistence) * self.positions


class ScaffoldSystem:
    """Fixed-capacity store of scaffold points (positions and colors).

    Only the first ``count`` rows of ``positions``/``colors`` are live. Once
    ``config.max_scaffold`` points exist, further additions are ignored.
    """

    def __init__(self, config: VisualizationConfig):
        self.config = config
        capacity = config.max_scaffold
        self.positions = np.zeros((capacity, 3), dtype=np.float32)
        self.colors = np.full((capacity, 3), 0.5, dtype=np.float32)
        self.count = 0

    def add_point(self, position: np.ndarray, color: Optional[np.ndarray] = None):
        """Append a point; uses a pale-blue color when none is given."""
        if self.count >= self.config.max_scaffold:
            return
        slot = self.count
        self.positions[slot] = position
        self.colors[slot] = [0.8, 0.8, 1.0] if color is None else color
        self.count = slot + 1

    def clear(self):
        """Forget every point (storage is kept and simply overwritten later)."""
        self.count = 0
        print("[Scaffold] Cleared all scaffold points")


class WiFiScanner:
    """Scans WiFi networks on a background daemon thread.

    The scan thread replaces ``networks`` with a fresh list of
    ``(ssid, signal)`` tuples sorted by SSID. ``get_stable_networks`` reads
    that list, gives each SSID a persistent integer index, and tracks recent
    signal samples per SSID.
    """

    def __init__(self):
        self.networks: List[Tuple[str, float]] = []
        self.stability_tracker: Dict[str, List[float]] = {}
        self.ssid_to_index: Dict[str, int] = {}
        self.next_index = 0
        self.running = True
        self._init_wifi()
        self._start_scanning()

    def _init_wifi(self):
        """Open the first WiFi interface; set ``iface`` to None on failure."""
        try:
            self.wifi = pywifi.PyWiFi()
            self.iface = self.wifi.interfaces()[0]
            print("[WiFi] Interface initialized successfully")
        except Exception as e:
            print(f"[WiFi Error] Failed to initialize: {e}")
            self.iface = None

    def _start_scanning(self):
        """Start the daemon thread that refreshes ``networks`` until ``stop``."""
        def scan_loop():
            while self.running:
                try:
                    if self.iface:
                        self.iface.scan()
                        time.sleep(1)
                        results = self.iface.scan_results()
                        # Rebind (not mutate) so readers always see a complete list.
                        self.networks = sorted([(r.ssid, r.signal) for r in results if r.ssid],
                                               key=lambda x: x[0])
                        if len(self.networks) > 0:
                            print(f"[WiFi] Found {len(self.networks)} networks")
                    else:
                        time.sleep(5)
                except Exception as e:
                    print(f"[WiFi Error] Scan failed: {e}")
                    time.sleep(2)

        threading.Thread(target=scan_loop, name="wifi-scan", daemon=True).start()

    def get_stable_networks(self, config: "VisualizationConfig") -> List[Tuple[str, float, int]]:
        """Record one signal sample per listed network and report stable ones.

        A network is reported once it has more than ``config.stable_frames``
        samples and the variance of its last ``stable_frames`` samples is
        below ``config.stable_threshold``. Its history is then reset. If the
        same SSID appears several times in ``networks``, each entry adds a
        sample.

        Returns:
            List of ``(ssid, mean_signal, index)`` where ``index`` is the
            SSID's persistent integer index.
        """
        window = config.stable_frames
        stable_networks = []

        for ssid, signal in self.networks:
            if ssid not in self.ssid_to_index:
                self.ssid_to_index[ssid] = self.next_index
                self.next_index += 1

            history = self.stability_tracker.setdefault(ssid, [])
            history.append(signal)
            if len(history) <= window:
                continue

            recent = history[-window:]
            if np.var(recent) < config.stable_threshold:
                stable_networks.append((ssid, np.mean(recent), self.ssid_to_index[ssid]))
                self.stability_tracker[ssid] = []  # Reset after marking as stable
            else:
                # Keep only the window just examined. The next sample makes the
                # list long enough to re-check, and a network that never
                # settles can no longer grow its history without bound.
                del history[:-window]

        return stable_networks

    def stop(self):
        """Ask the scan thread to exit after its current iteration."""
        self.running = False


class ControlPanel:
    """Tkinter control window (sliders and buttons) running on a daemon thread.

    Read slider values with ``get_parameters``. The buttons only record
    requests: ``rendering_paused`` is toggled, and ``clear_scaffold_requested``
    and ``reset_view_requested`` are set to True. The code that owns the scene
    is expected to poll those flags and reset them after acting on them.
    """

    def __init__(self, config: "VisualizationConfig"):
        self.config = config
        self.root: Optional[tk.Tk] = None
        self.amp_var: Optional[tk.DoubleVar] = None
        self.noise_var: Optional[tk.DoubleVar] = None
        self.morph_var: Optional[tk.DoubleVar] = None
        self.persistence_var: Optional[tk.DoubleVar] = None
        # True only while the window is built and its event loop is running.
        self.ready = False
        self.rendering_paused = False
        # One-shot requests raised by the buttons. They are created before the
        # GUI thread starts so they always exist for readers.
        self.clear_scaffold_requested = False
        self.reset_view_requested = False
        self._init_gui()

    def _init_gui(self):
        """Build the window and run its Tk event loop on a daemon thread."""
        def setup_gui():
            try:
                self.root = tk.Tk()
                self.root.title("WiFi Visualizer Controls")
                self.root.geometry("400x300")

                # Create parameter sliders
                self.amp_var = self._create_slider("Amplitude", 0.1, 2.0, self.config.default_amplitude)
                self.noise_var = self._create_slider("Noise", 0.0, 1.0, self.config.default_noise)
                self.morph_var = self._create_slider("Polar Transform", 0.0, 1.0, self.config.default_morph)
                self.persistence_var = self._create_slider("Trail Persistence", 0.5, 0.99, self.config.default_persistence)

                # Control buttons
                button_frame = ttk.Frame(self.root)
                button_frame.pack(fill='x', padx=10, pady=10)

                ttk.Button(button_frame, text="Pause/Resume",
                           command=self._toggle_pause).pack(side='left', padx=5)
                ttk.Button(button_frame, text="Clear Scaffold",
                           command=self._clear_scaffold).pack(side='left', padx=5)
                ttk.Button(button_frame, text="Reset View",
                           command=self._reset_view).pack(side='left', padx=5)

                self.ready = True
                print("[GUI] Control panel initialized successfully")
                self.root.mainloop()

            except Exception as e:
                print(f"[GUI Error] Failed to initialize: {e}")
            finally:
                # After setup fails or the window closes, get_parameters falls
                # back to defaults instead of querying a Tk interpreter whose
                # event loop is no longer running.
                self.ready = False

        threading.Thread(target=setup_gui, daemon=True).start()

    def _create_slider(self, label: str, from_: float, to: float, initial: float) -> tk.DoubleVar:
        """Create a labeled horizontal slider with a live value readout."""
        frame = ttk.LabelFrame(self.root, text=label, padding=10)
        frame.pack(fill='x', padx=10, pady=5)

        var = tk.DoubleVar(value=initial)
        scale = ttk.Scale(frame, from_=from_, to=to, orient='horizontal', variable=var)
        scale.pack(fill='x')

        # Value display
        value_label = ttk.Label(frame, text=f"{initial:.2f}")
        value_label.pack()

        def update_label(*args):
            value_label.config(text=f"{var.get():.2f}")

        # trace_add replaces the deprecated trace('w', ...) API.
        var.trace_add('write', update_label)
        return var

    def _toggle_pause(self):
        """Toggle the rendering pause flag."""
        self.rendering_paused = not self.rendering_paused
        print(f"[GUI] Rendering {'paused' if self.rendering_paused else 'resumed'}")

    def _clear_scaffold(self):
        """Request that the scaffold be cleared (consumer resets the flag)."""
        self.clear_scaffold_requested = True

    def _reset_view(self):
        """Request a camera reset (consumer resets the flag)."""
        self.reset_view_requested = True

    def _default_parameters(self) -> Dict[str, float]:
        """Parameter values taken from the config defaults."""
        return {
            'amplitude': self.config.default_amplitude,
            'noise': self.config.default_noise,
            'morph': self.config.default_morph,
            'persistence': self.config.default_persistence
        }

    def get_parameters(self) -> Dict[str, float]:
        """Return the slider values, or the config defaults if they cannot be read."""
        if not self.ready:
            return self._default_parameters()

        try:
            return {
                'amplitude': self.amp_var.get(),
                'noise': self.noise_var.get(),
                'morph': self.morph_var.get(),
                'persistence': self.persistence_var.get()
            }
        except Exception:
            # Best-effort: a failed read (e.g. the window is being torn down)
            # must not break the caller's frame, so fall back to defaults.
            return self._default_parameters()


class WiFiVisualizer:
    """Main visualizer: owns the simulation objects and all OpenGL state.

    Construct it after a GL context exists, call ``init_opengl`` once, then
    drive it through the GLUT callbacks (``render``, ``idle`` and the
    ``handle_*`` methods).
    """

    VERTEX_SHADER = """
    #version 330 core
    layout(location = 0) in vec3 position;
    layout(location = 1) in vec3 color;
    
    uniform mat4 mvp_matrix;
    uniform float morph_factor;
    uniform float radius_time;
    uniform float point_size;
    
    out vec3 vertex_color;
    
    void main() {
        vec3 pos = position;
        
        // Apply polar transformation based on morph factor
        if (morph_factor > 0.01) {
            float r = length(pos.xy) * (1.0 + radius_time * 0.5);
            float theta = atan(pos.y, pos.x);
            pos.x = r * cos(theta) * morph_factor + pos.x * (1.0 - morph_factor);
            pos.y = r * sin(theta) * morph_factor + pos.y * (1.0 - morph_factor);
        }
        
        gl_Position = mvp_matrix * vec4(pos, 1.0);
        // With GL_PROGRAM_POINT_SIZE enabled glPointSize() is ignored, so the
        // size comes from a uniform that each draw call sets.
        gl_PointSize = point_size;
        vertex_color = color;
    }
    """
    
    FRAGMENT_SHADER = """
    #version 330 core
    in vec3 vertex_color;
    out vec4 frag_color;
    
    void main() {
        // Create circular points
        vec2 coord = gl_PointCoord - vec2(0.5);
        if (length(coord) > 0.5) discard;
        
        float alpha = 1.0 - length(coord) * 2.0;
        frag_color = vec4(vertex_color, alpha * 0.8);
    }
    """

    PARTICLE_POINT_SIZE = 6.0
    SCAFFOLD_POINT_SIZE = 12.0
    # Rolling window size for frame times; also how often the average is printed.
    FRAME_REPORT_INTERVAL = 60
    _UNIFORM_NAMES = ("mvp_matrix", "morph_factor", "radius_time", "point_size")

    def __init__(self):
        self.config = VisualizationConfig()
        self.camera = CameraState()
        self.particles = ParticleSystem(self.config)
        self.scaffold = ScaffoldSystem(self.config)
        self.wifi_scanner = WiFiScanner()
        self.control_panel = ControlPanel(self.config)

        # OpenGL resources (created by init_opengl once a context exists)
        self.shader_program = None
        self.particle_vao = None
        self.scaffold_vao = None
        self.particle_buffers = None
        self.scaffold_buffers = None
        self._uniform_locations = {}
        self.mvp_matrix = np.eye(4, dtype=np.float32)

        # Animation state
        self.radius_time = 0.0
        self.last_time = time.time()
        self.frame_times = []
        self._frame_counter = 0

        # Mouse state
        self.mouse_last_pos = None
        self.is_orbiting = False

        # Control flags
        self.clear_scaffold_requested = False
        self.reset_view_requested = False
        self._cleaned_up = False

    def _compile_shaders(self):
        """Compile and link the shader program and cache its uniform locations."""
        vertex_shader = compileShader(self.VERTEX_SHADER, GL_VERTEX_SHADER)
        fragment_shader = compileShader(self.FRAGMENT_SHADER, GL_FRAGMENT_SHADER)
        self.shader_program = compileProgram(vertex_shader, fragment_shader)
        # Locations are fixed once the program is linked; look them up once
        # here instead of on every frame.
        self._uniform_locations = {
            name: glGetUniformLocation(self.shader_program, name)
            for name in self._UNIFORM_NAMES
        }

    @staticmethod
    def _create_point_vao(positions, colors):
        """Create a VAO with position (location 0) and color (location 1) buffers.

        Both arrays are uploaded as GL_DYNAMIC_DRAW vec3 float attributes.
        Returns ``(vao, (position_buffer, color_buffer))``.
        """
        vao = glGenVertexArrays(1)
        glBindVertexArray(vao)
        buffers = []
        for location, data in enumerate((positions, colors)):
            buffer = glGenBuffers(1)
            glBindBuffer(GL_ARRAY_BUFFER, buffer)
            glBufferData(GL_ARRAY_BUFFER, data.nbytes, data, GL_DYNAMIC_DRAW)
            glVertexAttribPointer(location, 3, GL_FLOAT, GL_FALSE, 0, None)
            glEnableVertexAttribArray(location)
            buffers.append(buffer)
        glBindVertexArray(0)
        return vao, tuple(buffers)

    def _setup_particle_vao(self):
        """Setup particle vertex array object."""
        self.particle_vao, self.particle_buffers = self._create_point_vao(
            self.particles.positions, self.particles.colors)

    def _setup_scaffold_vao(self):
        """Setup scaffold vertex array object."""
        self.scaffold_vao, self.scaffold_buffers = self._create_point_vao(
            self.scaffold.positions, self.scaffold.colors)

    @staticmethod
    def _compose_mvp(projection, modelview):
        """Combine matrices read back with glGetFloatv into one MVP array.

        glGetFloatv returns the matrices in OpenGL's column-major layout, so
        each NumPy array holds the transpose of its mathematical matrix. The
        required MVP is P*M, and (P*M)^T = M^T * P^T = modelview @ projection.
        The result is again column-major, which matches the transpose=GL_FALSE
        upload in render().
        """
        mv = np.asarray(modelview, dtype=np.float32)
        proj = np.asarray(projection, dtype=np.float32)
        return (mv @ proj).astype(np.float32)

    def _update_mvp_matrix(self):
        """Rebuild ``mvp_matrix`` from the camera and the current projection."""
        glMatrixMode(GL_MODELVIEW)
        glLoadIdentity()
        glTranslatef(0, 0, -self.camera.distance)
        glRotatef(self.camera.angle_y, 1, 0, 0)
        glRotatef(self.camera.angle_x, 0, 1, 0)

        # Both stacks can be queried regardless of the current matrix mode.
        modelview = glGetFloatv(GL_MODELVIEW_MATRIX)
        projection = glGetFloatv(GL_PROJECTION_MATRIX)
        self.mvp_matrix = self._compose_mvp(projection, modelview)

    def init_opengl(self):
        """Initialize OpenGL resources."""
        self._compile_shaders()
        self._setup_particle_vao()
        self._setup_scaffold_vao()

        # OpenGL state
        glEnable(GL_PROGRAM_POINT_SIZE)
        glEnable(GL_BLEND)
        glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)
        glEnable(GL_DEPTH_TEST)
        glClearColor(0.02, 0.02, 0.05, 1.0)

        print("[OpenGL] Initialization complete")

    def update_simulation(self):
        """Advance particles, recolor them from WiFi data and grow the scaffold."""
        current_time = time.time()
        dt = current_time - self.last_time
        self.last_time = current_time
        # The clock is advanced even while paused so resuming does not
        # produce one huge dt.
        if self.control_panel.rendering_paused:
            return

        # Animation phase wraps smoothly in [0, 2*pi).
        self.radius_time = (self.radius_time + dt * 2.0) % (2 * np.pi)

        params = self.control_panel.get_parameters()
        self.particles.update(params['amplitude'], params['noise'], params['persistence'])

        # Color one particle per network by signal strength (red = weak, green = strong).
        self.particles.active_count = 0
        for ssid, signal in self.wifi_scanner.networks:
            if ssid not in self.wifi_scanner.ssid_to_index:
                self.wifi_scanner.ssid_to_index[ssid] = self.wifi_scanner.next_index
                self.wifi_scanner.next_index += 1

            index = self.wifi_scanner.ssid_to_index[ssid]
            if index >= self.config.max_particles:
                continue

            # Map a signal of -100 to 0 and -50 or stronger to 1.
            strength = np.clip((signal + 100) / 50.0, 0, 1)
            self.particles.colors[index] = [1 - strength, strength, 0.3]
            self.particles.active_count = max(self.particles.active_count, index + 1)

        # Drop a scaffold point for each network that has become stable.
        stable_networks = self.wifi_scanner.get_stable_networks(self.config)
        for ssid, mean_signal, index in stable_networks:
            if index < self.config.max_particles:
                strength = np.clip((mean_signal + 100) / 50.0, 0, 1)
                position = self.particles.positions[index] * (0.5 + strength * 1.0)
                self.scaffold.add_point(position, [0.8, 0.8 + strength * 0.2, 1.0])

    def _consume_control_requests(self):
        """Apply pending clear-scaffold / reset-view requests exactly once.

        Requests come from this object's own flags (keyboard and mouse) and
        from matching flags on the control panel, if it exposes them. The
        getattr defaults keep this working with panels that don't.
        """
        panel = self.control_panel
        if getattr(panel, "clear_scaffold_requested", False):
            panel.clear_scaffold_requested = False
            self.clear_scaffold_requested = True
        if getattr(panel, "reset_view_requested", False):
            panel.reset_view_requested = False
            self.reset_view_requested = True

        if self.clear_scaffold_requested:
            self.scaffold.clear()
            self.clear_scaffold_requested = False
        if self.reset_view_requested:
            self.camera.reset()
            self.reset_view_requested = False

    @staticmethod
    def _upload_points(buffers, positions, colors):
        """Overwrite the start of the (position, color) buffers with new data."""
        for buffer, data in zip(buffers, (positions, colors)):
            glBindBuffer(GL_ARRAY_BUFFER, buffer)
            glBufferSubData(GL_ARRAY_BUFFER, 0, data.nbytes, data)

    def _record_frame_time(self, frame_time):
        """Track a rolling window of frame times and report the average periodically.

        Returns the average frame time in seconds on frames where it is
        printed (every FRAME_REPORT_INTERVAL frames), otherwise None.
        """
        window = self.FRAME_REPORT_INTERVAL
        self.frame_times.append(frame_time)
        if len(self.frame_times) > window:
            del self.frame_times[0]
        self._frame_counter += 1
        if self._frame_counter % window:
            return None
        avg_frame_time = float(np.mean(self.frame_times))
        fps = f"{1.0 / avg_frame_time:.1f}" if avg_frame_time > 0 else "inf"
        print(f"[Performance] Avg frame time: {avg_frame_time*1000:.2f}ms ({fps} FPS)")
        return avg_frame_time

    def render(self):
        """Render the current frame."""
        start_time = time.perf_counter()

        glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)
        self._consume_control_requests()
        self._update_mvp_matrix()

        glUseProgram(self.shader_program)
        params = self.control_panel.get_parameters()
        locs = self._uniform_locations
        glUniformMatrix4fv(locs["mvp_matrix"], 1, GL_FALSE, self.mvp_matrix)
        glUniform1f(locs["morph_factor"], params['morph'])
        glUniform1f(locs["radius_time"], self.radius_time)

        # Only the first `count` rows are drawn, so only they are uploaded.
        count = self.particles.active_count
        if count > 0:
            self._upload_points(self.particle_buffers,
                                self.particles.positions[:count],
                                self.particles.colors[:count])
            glUniform1f(locs["point_size"], self.PARTICLE_POINT_SIZE)
            glBindVertexArray(self.particle_vao)
            glDrawArrays(GL_POINTS, 0, count)

        count = self.scaffold.count
        if count > 0:
            self._upload_points(self.scaffold_buffers,
                                self.scaffold.positions[:count],
                                self.scaffold.colors[:count])
            glUniform1f(locs["point_size"], self.SCAFFOLD_POINT_SIZE)
            glBindVertexArray(self.scaffold_vao)
            glDrawArrays(GL_POINTS, 0, count)

        glutSwapBuffers()

        # perf_counter: monotonic and high resolution, unlike time.time().
        self._record_frame_time(time.perf_counter() - start_time)

    def handle_keyboard(self, key: bytes, x: int, y: int):
        """GLUT keyboard callback; letter keys are matched case-insensitively."""
        key = key.lower()
        if key == b'p':
            self.control_panel._toggle_pause()
        elif key == b's':
            self.clear_scaffold_requested = True
        elif key == b'r':
            self.reset_view_requested = True
        elif key in (b'+', b'='):
            self.camera.zoom(1)
        elif key == b'-':
            self.camera.zoom(-1)
        elif key == b'\x1b':  # Escape key
            self.cleanup()
            sys.exit(0)

    def handle_mouse(self, button: int, state: int, x: int, y: int):
        """Left button drags to orbit; right click requests a view reset."""
        if button == GLUT_LEFT_BUTTON:
            self.is_orbiting = (state == GLUT_DOWN)
            self.mouse_last_pos = (x, y) if self.is_orbiting else None
        elif button == GLUT_RIGHT_BUTTON and state == GLUT_DOWN:
            self.reset_view_requested = True

    def handle_mouse_motion(self, x: int, y: int):
        """Handle mouse motion for camera orbiting."""
        if self.is_orbiting and self.mouse_last_pos:
            dx = x - self.mouse_last_pos[0]
            dy = y - self.mouse_last_pos[1]
            self.camera.rotate(dx, dy)
            self.mouse_last_pos = (x, y)

    def handle_mouse_wheel(self, button: int, direction: int, x: int, y: int):
        """Handle mouse wheel for zooming."""
        self.camera.zoom(direction)

    def idle(self):
        """Idle callback for continuous updates."""
        self.update_simulation()
        glutPostRedisplay()

    def cleanup(self):
        """Stop background work; safe to call more than once."""
        if getattr(self, "_cleaned_up", False):
            return
        self._cleaned_up = True
        print("[Cleanup] Shutting down...")
        self.wifi_scanner.stop()
        root = self.control_panel.root
        if root:
            try:
                root.quit()
            except Exception as e:
                # Best-effort: the window may already be gone.
                print(f"[Cleanup] Could not stop control panel: {e}")


def main():
    """Create the GLUT window, set up the projection, wire callbacks and run."""
    print("WiFi Room Visualizer - Enhanced Edition")
    print("Controls:")
    print("  Mouse: Left drag to orbit, right click to reset view")
    print("  Keyboard: P=pause, S=clear scaffold, R=reset view, +/-=zoom, ESC=exit")

    width, height = 1280, 720

    # Initialize GLUT
    glutInit(sys.argv)
    glutInitDisplayMode(GLUT_RGBA | GLUT_DOUBLE | GLUT_DEPTH)
    glutInitWindowSize(width, height)
    glutCreateWindow(b"WiFi Room Visualizer - Enhanced")

    # Create and initialize visualizer (requires the GL context created above)
    visualizer = WiFiVisualizer()
    visualizer.init_opengl()

    # Perspective projection. gluPerspective is a GLU function and is not
    # provided by the OpenGL.GL / OpenGL.GLUT star imports, so the equivalent
    # frustum is built with core glFrustum.
    fov_y_degrees, z_near, z_far = 60.0, 0.1, 100.0
    top = z_near * float(np.tan(np.radians(fov_y_degrees) / 2.0))
    right = top * (width / height)
    glMatrixMode(GL_PROJECTION)
    glLoadIdentity()
    glFrustum(-right, right, -top, top, z_near, z_far)
    glMatrixMode(GL_MODELVIEW)

    # Register callbacks
    glutDisplayFunc(visualizer.render)
    glutIdleFunc(visualizer.idle)
    glutKeyboardFunc(visualizer.handle_keyboard)
    glutMouseFunc(visualizer.handle_mouse)
    glutMotionFunc(visualizer.handle_mouse_motion)

    try:
        glutMouseWheelFunc(visualizer.handle_mouse_wheel)
    except Exception:
        # Not every GLUT implementation provides a mouse-wheel callback.
        print("[Warning] Mouse wheel not supported on this system")

    print("[Main] Starting visualization...")
    try:
        glutMainLoop()
    except KeyboardInterrupt:
        print("\n[Main] Interrupted by user")
    finally:
        visualizer.cleanup()


if __name__ == "__main__":
    main()